archived/rl_gamerserver_ray/common/sagemaker_rl/mpi_launcher.py
import os
import shlex
import signal
import socket
import stat
import subprocess
import sys
import textwrap
import time
from contextlib import contextmanager
import sagemaker_containers
from retrying import retry
from sagemaker_containers import _logging
from sagemaker_containers.beta import framework
logger = _logging.get_logger()
# MPI files.
_MPI_SCRIPT = "/mpi_script.sh"
_MPI_IS_RUNNING = "/mpi_is_running"
_MPI_IS_FINISHED = "/mpi_is_finished"
_CHANGE_HOSTNAME_LIBRARY = "/libchangehostname.so"
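# Sentinel-file protocol: the generated MPI script touches _MPI_IS_RUNNING when
# training starts and _MPI_IS_FINISHED when it exits; worker nodes poll these
# files to know when training has begun and when it is safe to shut down.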
def _change_hostname(current_host):
"""Compiles a shared library to correct the behavior of the gethostname system call,
which OpenMPI depends on.
Args:
current_host (str): name of the current host, such as algo-1, algo-2, etc.
"""
os.system("/change-hostname.sh {}".format(current_host))
def _start_ssh_daemon():
"""Starts the ssh deamon"""
subprocess.Popen(["/usr/sbin/sshd", "-D"])
def _setup_mpi_environment(env):
"""Setup MPI environment, i.e. executing change hostname scrip and starting ssh deamon."""
_change_hostname(env.current_host)
_start_ssh_daemon()
def _can_connect(host, port, s):
"""Checks if the connection to provided ``host`` and ``port`` is possible or not."""
    try:
        print("Testing connection to host {}".format(host))
        s.connect((host, port))
        print("Can connect to host {}".format(host))
        return True
    except socket.error:
        print("Can't connect to host {}".format(host))
        return False
    finally:
        # Close the socket on both outcomes; the original leaked it on failure.
        s.close()
def _create_mpi_script(env, train_script, train_script_args):
"""Creates a MPI script with user provided information.
For distributed training: the 'master node' runs mpirun with this script,
'/mpi_script.sh'.
This script creates a file '/mpi_is_running' that worker nodes use to
determine whether training # (started by MPI from the master node) is still running.
Processes on worker nodes use # /mpi_is_finished file to determine when to exit.
Args:
env (TrainingEnv): an instance of the training environment.
"""
hyperparameters = framework.mapping.to_cmd_args(env.hyperparameters)
channels = framework.mapping.to_cmd_args(env.channel_input_dirs)
python_cmd = [sys.executable, train_script]
    python_cmd.extend(train_script_args or [])  # MPILauncher defaults train_script_args to None
python_cmd.extend(hyperparameters)
python_cmd.extend(channels)
content = textwrap.dedent(
"""#!/usr/bin/env bash
touch /mpi_is_running
%s
EXIT_CODE=$?
touch /mpi_is_finished
exit ${EXIT_CODE}
"""
% " ".join(python_cmd)
)
with open(_MPI_SCRIPT, "w") as w:
w.write(content)
st = os.stat(_MPI_SCRIPT)
os.chmod(_MPI_SCRIPT, st.st_mode | stat.S_IEXEC)
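# For illustration only: with a hypothetical train_script of "train.py", no
# extra arguments, and a Python interpreter at /usr/bin/python, the generated
# /mpi_script.sh would look roughly like:
#
#   #!/usr/bin/env bash
#   touch /mpi_is_running
#   /usr/bin/python train.py
#   EXIT_CODE=$?
#   touch /mpi_is_finished
#   exit ${EXIT_CODE}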
class MPIMaster(object):
"""MPI Master
Args:
env (TrainingEnv): an instance of the training environment.
process_per_host (int): Number of processes per host to be executed by MPI
        instance_type (str): Type of instance used for this job. It will be "local" for local mode. It is used to
            perform different setup for local mode versus SageMaker mode.
"""
def __init__(self, env, process_per_host, instance_type):
self.env = env
self.process_per_host = process_per_host
self.instance_type = instance_type
def _wait_for_worker_nodes_to_start_sshd(self, hosts, interval=1, timeout_in_seconds=180):
"""Wait for worker nodes to start their ssh deamon to allow MPI communication."""
with timeout(seconds=timeout_in_seconds):
while hosts:
print("hosts that aren't SSHable yet: {}".format(str(hosts)))
for host in hosts:
ssh_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if _can_connect(host, 22, ssh_socket):
print("Host: {} is sshable now.".format(host))
hosts.remove(host)
time.sleep(interval)
def _run_mpi_on_all_nodes(self):
"""Run MPI command to execute MPI_SCRIPT on all hosts."""
mpi_command = self._build_mpi_command()
cmd = shlex.split(mpi_command)
framework.logging.log_script_invocation(cmd, self.env.to_env_vars(), logger)
print("MPI Command: {}".format(mpi_command))
with open(_MPI_SCRIPT) as f:
print("Running user script:\n\n%s", f.read())
subprocess.check_call(cmd)
def _build_mpi_command(self):
"""Build MPI command."""
num_hosts = len(self.env.hosts)
num_processes = self.process_per_host * num_hosts
# By default, use one process per GPU, or one process per node (if training with CPU).
host_list = (
self.env.hosts
if self.process_per_host == 1
else [host + ":{}".format(self.process_per_host) for host in self.env.hosts]
)
print(
"Env Hosts: {} Hosts: {} process_per_hosts: {} num_processes: {}".format(
self.env.hosts, host_list, self.process_per_host, num_processes
)
)
credential_vars = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]
        interface_name = self.env.network_interface_name
if self.instance_type == "local":
interface_name = "eth0"
print("network interface name:" + interface_name + " " + str(self.instance_type))
mpi_command = (
"mpirun --host {}".format(",".join(host_list))
+ " -np {} ".format(num_processes)
+ " --allow-run-as-root"
+ " --display-map"
+ " --tag-output"
+ " -mca btl_tcp_if_include {}".format(interface_name)
+ " -mca oob_tcp_if_include {}".format(interface_name)
+ " -x NCCL_SOCKET_IFNAME={}".format(interface_name)
+ " --mca plm_rsh_no_tree_spawn 1"
+ " -mca orte_abort_on_non_zero_status 1"
+ " -x NCCL_MIN_NRINGS=8 -x NCCL_DEBUG=INFO"
+ " -x LD_LIBRARY_PATH -x PATH"
+ " -x LD_PRELOAD={}".format(_CHANGE_HOSTNAME_LIBRARY)
)
for v in credential_vars:
if v in os.environ:
mpi_command += " -x {}".format(v)
for name, value in self.env.to_env_vars().items():
mpi_command += ' -x {}="{}"'.format(name, value)
mpi_command += " {}".format(_MPI_SCRIPT)
return mpi_command
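    # For illustration only: with hosts ["algo-1", "algo-2"] and
    # process_per_host=4, the command built above begins roughly with
    #   mpirun --host algo-1:4,algo-2:4 -np 8 --allow-run-as-root ...
    # followed by the -mca/-x options and, finally, /mpi_script.sh.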
def __call__(self):
self._wait_for_worker_nodes_to_start_sshd(self.env.hosts.copy())
self._run_mpi_on_all_nodes()
def is_master(self, hosts, current_host):
"""Checks if the current host is master or worker."""
print("Hosts: " + str(hosts) + " current host: " + str(current_host))
return current_host == sorted(list(hosts))[0]
class MPIWorker(object):
"""MPI Worker"""
@retry(
stop_max_delay=30000 * 1000, wait_fixed=1000, retry_on_result=lambda result: result is False
)
def _wait_for_mpi_to_start_running(self):
"""Wait and retry loop until the MPI training starts on this worker."""
return os.path.isfile(_MPI_IS_RUNNING)
@retry(wait_fixed=5000, retry_on_result=lambda result: result is False)
def _wait_until_mpi_stops_running(self):
"""Wait and retry loop until the MPI training is finished on this worker."""
return os.path.isfile(_MPI_IS_FINISHED)
def __call__(self, env):
current_host = env.current_host
print("Worker node {} is waiting for MPI to start training process".format(current_host))
self._wait_for_mpi_to_start_running()
print("MPI started training process on worker node {}".format(current_host))
self._wait_until_mpi_stops_running()
print("Training process started by MPI on worker node %s stopped", current_host)
class TimeoutError(Exception):  # intentionally shadows the builtin; raised only by the timeout() helper below
    pass
@contextmanager
def timeout(seconds=0, minutes=0, hours=0):
"""
Add a signal-based timeout to any block of code.
If multiple time units are specified, they will be added together to determine time limit.
Usage:
with timeout(seconds=5):
my_slow_function(...)
Args:
- seconds: The time limit, in seconds.
- minutes: The time limit, in minutes.
- hours: The time limit, in hours.
"""
limit = seconds + 60 * minutes + 3600 * hours
def handler(signum, frame): # pylint: disable=W0613
raise TimeoutError("timed out after {} seconds".format(limit))
    old_handler = signal.signal(signal.SIGALRM, handler)
    try:
        signal.setitimer(signal.ITIMER_REAL, limit)
        yield
    finally:
        # Cancel the alarm and restore the previous SIGALRM handler.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, old_handler)
class MPILauncher(object):
"""
    MPI launcher for algorithms that support MPI-based distributed training.
    Args:
        train_script (str): Training script to be executed by the ``MPILauncher``.
        train_script_args (list): List of args passed to the ``train_script`` executed by the ``MPILauncher``.
        num_of_processes_per_host (int): Number of processes per host to be executed by MPI.
        instance_type (str): Type of instance used for this job. It will be "local" for local mode. It is used to
            perform different setup for local mode versus SageMaker mode.
"""
def __init__(
        self, train_script, train_script_args=None, num_of_processes_per_host=1, instance_type=None
):
self._train_script = train_script
self._train_script_args = train_script_args
self._num_of_processes_per_host = num_of_processes_per_host
self._instance_type = instance_type
def mpi_run(self):
env = sagemaker_containers.training_env()
print("MPI requested with process per hosts: {}".format(self._num_of_processes_per_host))
_setup_mpi_environment(env)
_create_mpi_script(env, self._train_script, self._train_script_args)
mpi_master = MPIMaster(env, self._num_of_processes_per_host, self._instance_type)
if mpi_master.is_master(env.hosts, env.current_host):
print("Inside Master")
mpi_master()
else:
print("Inside Worker")
MPIWorker()(env)
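# Hypothetical usage sketch (the entry-point script name "train.py" is an
# assumption, not part of this module): inside a SageMaker training container,
# a framework entry point might launch MPI training like so:
#
#   if __name__ == "__main__":
#       MPILauncher("train.py", num_of_processes_per_host=1).mpi_run()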